{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting Finetune_SGD.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile Finetune_SGD.py\n",
    "from __future__ import print_function, division\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms\n",
    "import matplotlib.pyplot as pltD\n",
    "import time\n",
    "import copy\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append('General_utils')\n",
    "\n",
    "from ImageFolderTrainVal import *\n",
    "from test_network import *\n",
    "from SGD_Training import *\n",
    "\n",
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appending to Finetune_SGD.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile -a Finetune_SGD.py\n",
    "\n",
    "\n",
    "def exp_lr_scheduler(optimizer, epoch, init_lr=0.0008, lr_decay_epoch=45):\n",
    "    \"\"\"Decay learning rate by a factor of 0.1 every lr_decay_epoch epochs.\"\"\"\n",
    "    lr = init_lr * (0.1**(epoch // lr_decay_epoch))\n",
    "    print('lr is '+str(lr))\n",
    "    if epoch % lr_decay_epoch == 0:\n",
    "        print('LR is set to {}'.format(lr))\n",
    "\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = lr\n",
    "\n",
    "    return optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Appending to Finetune_SGD.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile -a Finetune_SGD.py\n",
    "\n",
    "def fine_tune_SGD(dataset_path,model_path,exp_dir,batch_size=100, num_epochs=100,lr=0.0004,init_freeze=1):\n",
    "   \n",
    "    print('lr is ' + str(lr))\n",
    "    \n",
    "    dsets = torch.load(dataset_path)\n",
    "    dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=batch_size,\n",
    "                                               shuffle=True, num_workers=4)\n",
    "                for x in ['train', 'val']}\n",
    "    dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']}\n",
    "    dset_classes = dsets['train'].classes\n",
    "\n",
    "    use_gpu = torch.cuda.is_available()\n",
    "    resume=os.path.join(exp_dir,'epoch.pth.tar')\n",
    "    if os.path.isfile(resume):\n",
    "            checkpoint = torch.load(resume)\n",
    "            model_ft = checkpoint['model']\n",
    "    if not os.path.isfile(model_path):\n",
    "        model_ft = models.alexnet(pretrained=True)\n",
    "       \n",
    "    else:\n",
    "        model_ft=torch.load(model_path)\n",
    "    if not init_freeze:    \n",
    "        num_ftrs = model_ft.classifier[6].in_features \n",
    "        model_ft.classifier._modules['6'] = nn.Linear(num_ftrs, len(dset_classes))    \n",
    "    if not os.path.exists(exp_dir):\n",
    "        os.makedirs(exp_dir)\n",
    "    if use_gpu:\n",
    "        model_ft = model_ft.cuda()\n",
    "\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "    \n",
    "    optimizer_ft =  optim.SGD(model_ft.parameters(), lr, momentum=0.9)\n",
    "\n",
    "        \n",
    "    \n",
    "  \n",
    "    model_ft = train_model(model_ft, criterion, optimizer_ft,exp_lr_scheduler,lr, dset_loaders,dset_sizes,use_gpu,num_epochs,exp_dir,resume)\n",
    "    \n",
    "    return model_ft\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
